{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1f242922",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['S字母:5278=wqqqrqqqrqqqrqqqwE',\n",
       " 'S大写:9113=叁陆肆伍伍陆肆伍伍陆肆伍伍陆肆伍贰E',\n",
       " 'S小写:1255=五〇二〇五〇二〇五〇二〇五〇二〇EP',\n",
       " 'S大写:1198=肆柒玖贰肆柒玖贰肆柒玖贰肆柒玖贰EP',\n",
       " 'S大写:3751=壹伍零零伍伍零零伍伍零零伍伍零零肆E',\n",
       " 'S大写:5649=贰贰伍玖捌贰伍玖捌贰伍玖捌贰伍玖陆E',\n",
       " 'S小写:6892=二七五七〇七五七〇七五七〇七五六八E',\n",
       " 'S小写:2195=八七八〇八七八〇八七八〇八七八〇EP',\n",
       " 'S字母:6627=wytqpytqpytqpytpiE',\n",
       " 'S大写:7516=叁零零陆柒零零陆柒零零陆柒零零陆肆E']"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%run common.ipynb\n",
    "\n",
    "[tokenizer.decode(i) for i in tokenizer.get_batch_data(prefix=True)[1]][:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e7063d80",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([64, 9]), torch.Size([64, 18]))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@torch.no_grad()\n",
    "def get_question_and_answer():\n",
    "    _, token, _ = tokenizer.get_batch_data(prefix=True)\n",
    "\n",
    "    split = [i.index(tokenizer.encoder['=']) + 1 for i in token]\n",
    "\n",
    "    #只要问题部分,等号后面的内容切除\n",
    "    question = [t[:s] for t, s in zip(token, split)]\n",
    "    answer = [t[s:] for t, s in zip(token, split)]\n",
    "\n",
    "    #统一长度\n",
    "    lens = max([len(i) for i in question])\n",
    "    question = [[tokenizer.encoder['P']] * (lens - len(i)) + i\n",
    "                for i in question]\n",
    "    question = torch.LongTensor(question).to(device)\n",
    "\n",
    "    lens = max([len(i) for i in answer])\n",
    "    answer = [[tokenizer.encoder['P']] * (lens - len(i)) + i for i in answer]\n",
    "    answer = torch.LongTensor(answer).to(device)\n",
    "\n",
    "    return question, answer\n",
    "\n",
    "\n",
    "question, answer = get_question_and_answer()\n",
    "\n",
    "question.shape, answer.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "972f2d2e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 18])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_ppo = torch.load('ppo.model')\n",
    "model_ppo.to(device)\n",
    "model_ppo.eval()\n",
    "\n",
    "predict = generate(model_ppo.model_gen, question)\n",
    "predict = predict[:, question.shape[1]:]\n",
    "\n",
    "predict.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d4468184",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "S大写:7664= 叁零陆伍玖零陆伍玖零陆伍玖零陆伍陆E 叁零陆伍玖零陆伍玖零陆伍玖零陆伍陆E\n",
      "S小写:9893= 三九五七五九五七五九五七五九五七二E 三九五七五九五七五九五七五九五七二E\n",
      "S小写:6529= 二六一一八六一一八六一一八六一一六E 二六一一八六一一八六一一八六一一六E\n",
      "S小写:8139= 三二五五九二五五九二五五九二五五六E 三二五五九二五五九二五五九二五五六E\n",
      "S大写:7086= 贰捌叁肆陆捌叁肆陆捌叁肆陆捌叁肆肆E 贰捌叁肆陆捌叁肆陆捌叁肆陆捌叁肆肆E\n",
      "S大写:5099= 贰零叁玖捌零叁玖捌零叁玖捌零叁玖陆E 贰零叁玖捌零叁玖捌零叁玖捌零叁玖陆E\n",
      "S小写:9199= 三六七九九六七九九六七九九六七九六E 三六七九九六七九九六七九九六七九六E\n",
      "S大写:6229= 贰肆玖壹捌肆玖壹捌肆玖壹捌肆玖壹陆E 贰肆玖壹捌肆玖壹捌肆玖壹捌肆玖壹陆E\n",
      "S小写:6546= 二六一八六六一八六六一八六六一八四E 二六一八六六一八六六一八六六一八四E\n",
      "S数字:4435= 17741774177417740E 17741774177417740E\n",
      "S字母:3113= qwrtewrtewrtewrtwE qwrtewrtewrtewrtwE\n",
      "S小写:1271= 五〇八四五〇八四五〇八四五〇八四E 四〇八四四〇八四四〇八四四〇八四E\n",
      "S小写:1424= 五六九六五六九六五六九六五六九六E 五六九六五六九六五六九六五六九六E\n",
      "S大写:4742= 壹捌玖陆玖捌玖陆玖捌玖陆玖捌玖陆捌E 壹捌玖陆玖捌玖陆玖捌玖陆玖捌玖陆捌E\n",
      "S数字:3692= 14769476947694768E 14769476947694768E\n",
      "S数字:7171= 28686868686868684E 28686868686868684E\n",
      "S大写:2777= 壹壹壹零玖壹壹零玖壹壹零玖壹壹零捌E 壹壹壹零玖壹壹零玖壹壹零玖壹壹零捌E\n",
      "S小写:9911= 三九六四七九六四七九六四七九六四四E 三九六四七九六四七九六四七九六四四E\n",
      "S数字:4613= 18453845384538452E 18453845384538452E\n",
      "S大写:2593= 壹零叁柒叁零叁柒叁零叁柒叁零叁柒贰E 壹零叁柒叁零叁柒叁零叁柒叁零叁柒贰E\n",
      "S小写:9875= 三九五〇三九五〇三九五〇三九五〇〇E 三九五〇三九五〇三九五〇三九五〇〇E\n",
      "S数字:2299= 9196919691969196E 9196919691969196E\n",
      "S小写:2626= 一〇五〇五〇五〇五〇五〇五〇五〇四E 一〇五〇五〇五〇五〇五〇五〇五〇四E\n",
      "S小写:4443= 一七七七三七七七三七七七三七七七二E 一七七七三七七七三七七七三七七七二E\n",
      "S小写:2786= 一一一四五一一四五一一四五一一四四E 一一一四五一一四五一一四五一一四四E\n",
      "S字母:1077= repirepirepirepiE repirepirepirepiE\n",
      "S小写:8809= 三五二三九五二三九五二三九五二三六E 三五二三九五二三九五二三九五二三六E\n",
      "S小写:6800= 二七二〇二七二〇二七二〇二七二〇〇E 二七二〇二七二〇二七二〇二七二〇〇E\n",
      "S字母:7953= eqiqtqiqtqiqtqiqwE eqiqtqiqtqiqtqiqwE\n",
      "S小写:5907= 二三六三〇三六三〇三六三〇三六二八E 二三六三〇三六三〇三六三〇三六二八E\n",
      "S数字:6032= 24130413041304128E 24130413041304128E\n",
      "S大写:9967= 叁玖捌柒壹玖捌柒壹玖捌柒壹玖捌陆捌E 叁玖捌柒壹玖捌柒壹玖捌柒壹玖捌陆捌E\n",
      "S大写:9511= 叁捌零肆柒捌零肆柒捌零肆柒捌零肆肆E 叁捌零肆柒捌零肆柒捌零肆柒捌零肆肆E\n",
      "S字母:8200= ewipewipewipewippE ewipewipewipewippE\n",
      "S字母:5950= weipweipweipweippE weipweipweipweippE\n",
      "S小写:6737= 二六九五〇六九五〇六九五〇六九四八E 二六九五〇六九五〇六九五〇六九四八E\n",
      "S字母:3734= qroeuroeuroeuroeyE qroeuroeuroeuroeyE\n",
      "S数字:2278= 9112911291129112E 9112911291129112E\n",
      "S数字:8039= 32159215921592156E 32159215921592156E\n",
      "S字母:9799= eoqoooqoooqoooqoyE eoqoooqoooqoooqoyE\n",
      "S数字:2980= 11921192119211920E 11921192119211920E\n",
      "S字母:7377= wotqpotqpotqpotpiE wotqpotqpotqpotpiE\n",
      "S小写:5406= 二一六二六一六二六一六二六一六二四E 二一六二六一六二六一六二六一六二四E\n",
      "S小写:6428= 二五七一四五七一四五七一四五七一二E 二五七一四五七一四五七一四五七一二E\n",
      "S字母:2509= qppeuppeuppeuppeyE qppeuppeuppeuppeyE\n",
      "S大写:5187= 贰零柒伍零零柒伍零零柒伍零零柒肆捌E 贰零柒伍零零柒伍零零柒伍零零柒肆捌E\n",
      "S字母:7275= woqpwoqpwoqpwoqppE woqpwoqpwoqpwoqppE\n",
      "S数字:1783= 7132713271327132E 7132713271327132E\n",
      "S字母:2144= ituyituyituyituyE ituyituyituyituyE\n",
      "S字母:6048= wrqorrqorrqorrqowE wrqorrqorrqorrqowE\n",
      "S大写:1813= 柒贰伍贰柒贰伍贰柒贰伍贰柒贰伍贰E 柒贰伍贰柒贰伍贰柒贰伍贰柒贰伍贰E\n",
      "S小写:7707= 三〇八三一〇八三一〇八三一〇八二八E 三〇八三一〇八三一〇八三一〇八二八E\n",
      "S大写:8436= 叁叁柒肆柒叁柒肆柒叁柒肆柒叁柒肆肆E 叁叁柒肆柒叁柒肆柒叁柒肆柒叁柒肆肆E\n",
      "S小写:9010= 三六〇四三六〇四三六〇四三六〇四〇E 三六〇四三六〇四三六〇四三六〇四〇E\n",
      "S大写:1813= 柒贰伍贰柒贰伍贰柒贰伍贰柒贰伍贰E 柒贰伍贰柒贰伍贰柒贰伍贰柒贰伍贰E\n",
      "S数字:5718= 22874287428742872E 22874287428742872E\n",
      "S小写:4968= 一九八七三九八七三九八七三九八七二E 一九八七三九八七三九八七三九八七二E\n",
      "S数字:5132= 20530053005300528E 20530053005300528E\n",
      "S数字:6421= 25686568656865684E 25686568656865684E\n",
      "S小写:5481= 二一九二六一九二六一九二六一九二四E 二一九二六一九二六一九二六一九二四E\n",
      "S大写:9309= 叁柒贰叁玖柒贰叁玖柒贰叁玖柒贰叁陆E 叁柒贰叁玖柒贰叁玖柒贰叁玖柒贰叁陆E\n",
      "S数字:4044= 16177617761776176E 16177617761776176E\n",
      "S数字:6370= 25482548254825480E 25482548254825480E\n",
      "S小写:5980= 二三九二二三九二二三九二二三九二〇E 二三九二二三九二二三九二二三九二〇E\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.984375"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "correct = 0\n",
    "for q, a, p in zip(question, answer, predict):\n",
    "    q, a, p = q.tolist(), a.tolist(), p.tolist()\n",
    "\n",
    "    if tokenizer.encoder['E'] in a:\n",
    "        split = a.index(tokenizer.encoder['E']) + 1\n",
    "        a = a[:split]\n",
    "\n",
    "    if tokenizer.encoder['E'] in p:\n",
    "        split = p.index(tokenizer.encoder['E']) + 1\n",
    "        p = p[:split]\n",
    "\n",
    "    q, a, p = tokenizer.decode(q), tokenizer.decode(a), tokenizer.decode(p)\n",
    "\n",
    "    print(q, a, p)\n",
    "\n",
    "    correct += a == p\n",
    "\n",
    "correct / len(answer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:cuda117]",
   "language": "python",
   "name": "conda-env-cuda117-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
